{
 "cells": [
  {
   "cell_type": "raw",
   "metadata": {},
   "source": [
    "丢弃法\n",
    "dropout"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "import d2lzh as d2l\n",
    "from mxnet import autograd, gluon, init, nd\n",
    "from mxnet.gluon import loss as gloss, nn\n",
    "\n",
    "def dropout(X, drop_prob):\n",
    "    assert 0 <= drop_prob <= 1\n",
    "    keep_prob = 1 - drop_prob\n",
    "    # 这种情况下把全部元素都丢弃\n",
    "    if keep_prob == 0:\n",
    "        return X.zeros_like()\n",
    "    mask = nd.random.uniform(0, 1, X.shape) < keep_prob\n",
    "    return mask * X / keep_prob"
   ]
  },
  {
   "cell_type": "raw",
   "metadata": {},
   "source": [
    "测试dropout函数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "\n",
       "[[ 0.  1.  2.  3.  4.  5.  6.  7.]\n",
       " [ 8.  9. 10. 11. 12. 13. 14. 15.]]\n",
       "<NDArray 2x8 @cpu(0)>"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X = nd.arange(16).reshape((2, 8))\n",
    "dropout(X, 0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "\n",
       "[[ 0.  0.  0.  6.  0. 10. 12.  0.]\n",
       " [ 0.  0. 20. 22.  0.  0. 28. 30.]]\n",
       "<NDArray 2x8 @cpu(0)>"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dropout(X, 0.5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "\n",
       "[[0. 0. 0. 0. 0. 0. 0. 0.]\n",
       " [0. 0. 0. 0. 0. 0. 0. 0.]]\n",
       "<NDArray 2x8 @cpu(0)>"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dropout(X, 1)"
   ]
  },
  {
   "cell_type": "raw",
   "metadata": {},
   "source": [
    "定义模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_inputs, num_outputs, num_hiddens1, num_hiddens2 = 784, 10, 256, 256\n",
    "\n",
    "W1 = nd.random.normal(scale=0.01, shape=(num_inputs, num_hiddens1))\n",
    "b1 = nd.zeros(num_hiddens1)\n",
    "W2 = nd.random.normal(scale=0.01, shape=(num_hiddens1, num_hiddens2))\n",
    "b2 = nd.zeros(num_hiddens2)\n",
    "W3 = nd.random.normal(scale=0.01, shape=(num_hiddens2, num_outputs))\n",
    "b3 = nd.zeros(num_outputs)\n",
    "\n",
    "params = [W1, b1, W2, b2, W3, b3]\n",
    "for param in params:\n",
    "    param.attach_grad()"
   ]
  },
  {
   "cell_type": "raw",
   "metadata": {},
   "source": [
    "定义模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "drop_prob1, drop_prob2 = 0.6, 0.1\n",
    "\n",
    "def net(X):\n",
    "    X = X.reshape((-1, num_inputs))\n",
    "    H1 = (nd.dot(X, W1) + b1).relu()\n",
    "    if autograd.is_training():  # 只在训练模型时使用丢弃法\n",
    "        H1 = dropout(H1, drop_prob1)  # 在第一层全连接后添加丢弃层\n",
    "    H2 = (nd.dot(H1, W2) + b2).relu()\n",
    "    if autograd.is_training():\n",
    "        H2 = dropout(H2, drop_prob2)  # 在第二层全连接后添加丢弃层\n",
    "    return nd.dot(H2, W3) + b3"
   ]
  },
  {
   "cell_type": "raw",
   "metadata": {},
   "source": [
    "训练和测试模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 1, loss 1.1653, train acc 0.546, test acc 0.757\n",
      "epoch 2, loss 0.6119, train acc 0.768, test acc 0.826\n",
      "epoch 3, loss 0.5312, train acc 0.804, test acc 0.844\n",
      "epoch 4, loss 0.4942, train acc 0.818, test acc 0.849\n",
      "epoch 5, loss 0.4679, train acc 0.828, test acc 0.859\n",
      "epoch 6, loss 0.4469, train acc 0.835, test acc 0.855\n",
      "epoch 7, loss 0.4305, train acc 0.842, test acc 0.866\n",
      "epoch 8, loss 0.4175, train acc 0.845, test acc 0.864\n",
      "epoch 9, loss 0.4084, train acc 0.851, test acc 0.867\n",
      "epoch 10, loss 0.3984, train acc 0.855, test acc 0.870\n",
      "epoch 11, loss 0.3927, train acc 0.854, test acc 0.871\n",
      "epoch 12, loss 0.3871, train acc 0.858, test acc 0.876\n",
      "epoch 13, loss 0.3821, train acc 0.860, test acc 0.876\n",
      "epoch 14, loss 0.3774, train acc 0.861, test acc 0.875\n",
      "epoch 15, loss 0.3698, train acc 0.863, test acc 0.873\n",
      "epoch 16, loss 0.3661, train acc 0.865, test acc 0.877\n",
      "epoch 17, loss 0.3601, train acc 0.867, test acc 0.880\n",
      "epoch 18, loss 0.3565, train acc 0.868, test acc 0.881\n",
      "epoch 19, loss 0.3569, train acc 0.868, test acc 0.877\n",
      "epoch 20, loss 0.3507, train acc 0.870, test acc 0.881\n",
      "epoch 21, loss 0.3451, train acc 0.873, test acc 0.881\n",
      "epoch 22, loss 0.3476, train acc 0.872, test acc 0.880\n",
      "epoch 23, loss 0.3409, train acc 0.873, test acc 0.882\n",
      "epoch 24, loss 0.3399, train acc 0.874, test acc 0.883\n",
      "epoch 25, loss 0.3371, train acc 0.875, test acc 0.887\n",
      "epoch 26, loss 0.3310, train acc 0.876, test acc 0.883\n",
      "epoch 27, loss 0.3335, train acc 0.875, test acc 0.884\n",
      "epoch 28, loss 0.3294, train acc 0.877, test acc 0.886\n",
      "epoch 29, loss 0.3275, train acc 0.879, test acc 0.882\n",
      "epoch 30, loss 0.3257, train acc 0.879, test acc 0.887\n",
      "epoch 31, loss 0.3221, train acc 0.879, test acc 0.882\n",
      "epoch 32, loss 0.3224, train acc 0.881, test acc 0.888\n",
      "epoch 33, loss 0.3184, train acc 0.882, test acc 0.884\n",
      "epoch 34, loss 0.3160, train acc 0.883, test acc 0.891\n",
      "epoch 35, loss 0.3156, train acc 0.883, test acc 0.888\n",
      "epoch 36, loss 0.3116, train acc 0.885, test acc 0.885\n",
      "epoch 37, loss 0.3123, train acc 0.882, test acc 0.888\n",
      "epoch 38, loss 0.3085, train acc 0.884, test acc 0.884\n",
      "epoch 39, loss 0.3102, train acc 0.885, test acc 0.891\n",
      "epoch 40, loss 0.3065, train acc 0.886, test acc 0.891\n",
      "epoch 41, loss 0.3041, train acc 0.886, test acc 0.883\n",
      "epoch 42, loss 0.3028, train acc 0.886, test acc 0.887\n",
      "epoch 43, loss 0.3007, train acc 0.888, test acc 0.889\n",
      "epoch 44, loss 0.3007, train acc 0.888, test acc 0.890\n",
      "epoch 45, loss 0.3006, train acc 0.887, test acc 0.890\n",
      "epoch 46, loss 0.2996, train acc 0.888, test acc 0.885\n",
      "epoch 47, loss 0.2979, train acc 0.890, test acc 0.890\n",
      "epoch 48, loss 0.2937, train acc 0.891, test acc 0.886\n",
      "epoch 49, loss 0.2939, train acc 0.889, test acc 0.892\n",
      "epoch 50, loss 0.2938, train acc 0.890, test acc 0.888\n"
     ]
    }
   ],
   "source": [
    "num_epochs, lr, batch_size = 50, 0.5, 256\n",
    "loss = gloss.SoftmaxCrossEntropyLoss()\n",
    "train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)\n",
    "d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, batch_size,\n",
    "              params, lr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
